(WIP) Cleaned Flux2 Klein Implementation, with benchmarking done on v6 TPU #434

Draft: amepas wants to merge 4 commits into main from flux2klein-onboarding


amepas (Collaborator) commented Jun 29, 2026

(WIP) This will be updated with multi-chip latency numbers and support for Flux2 Klein 9B.

Draft PR for the Flux2 Klein model. It includes a custom implementation of the Qwen3-4B model for computing text embeddings. The VAE decoder, RoPE positional embedder, and flow-matching step schedule are all reused. The transformer/attention blocks are lightly modified.

Latency for a batch of 4 images at 1024×1024 (bfloat16):

  • Prompt Encoding (Qwen3): 57.67 ms (1.58% of total)
  • Denoising Loop (Flux 4 steps): 3,181.20 ms (87.09% of total)
    • Per-Step Transformer Time: 795.30 ms
  • VAE Decoding (VAE): 413.77 ms (11.33% of total)
  • Total: 3.65 seconds
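The percentages and the per-step time follow directly from the three stage timings. A minimal sketch that recomputes them, using only the numbers reported above (the dictionary keys and variable names are illustrative, not taken from the PR code):

```python
# Stage timings in milliseconds, copied from the breakdown above.
# The key names are illustrative and do not come from the PR code.
stage_ms = {
    "prompt_encoding": 57.67,
    "denoising_loop": 3181.20,
    "vae_decoding": 413.77,
}
num_denoising_steps = 4

# End-to-end latency is the sum of the three stages.
total_ms = sum(stage_ms.values())

# Share of the total spent in each stage, in percent.
share_pct = {name: 100.0 * ms / total_ms for name, ms in stage_ms.items()}

# Average transformer time per denoising step.
per_step_ms = stage_ms["denoising_loop"] / num_denoising_steps

for name, ms in stage_ms.items():
    print(f"{name}: {ms:.2f} ms ({share_pct[name]:.2f}% of total)")
print(f"per-step transformer time: {per_step_ms:.2f} ms")
print(f"total: {total_ms / 1000.0:.2f} s")
```

The denoising loop dominates the total, so it is the stage where multi-chip results would be expected to matter most.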

The PR includes code for verifying the accuracy of the implementation. Model sharding is implemented but not tested.
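For readers unfamiliar with this kind of check, a stage-by-stage accuracy comparison can be sketched as below. This is not the PR's verification code: the helper names, the flat-list inputs, and the tolerance are all assumptions chosen for illustration, and a real check would more likely use `numpy.testing.assert_allclose` on arrays.

```python
def max_abs_diff(actual, expected):
    """Largest element-wise absolute difference between two flat sequences."""
    if len(actual) != len(expected):
        raise ValueError(f"length mismatch: {len(actual)} vs {len(expected)}")
    return max((abs(a - e) for a, e in zip(actual, expected)), default=0.0)


def assert_stage_parity(stage_name, actual, expected, atol=1e-3):
    """Fail if a stage output drifts from its reference by more than atol.

    stage_name and atol are hypothetical; choose tolerances per stage.
    """
    diff = max_abs_diff(actual, expected)
    if diff > atol:
        raise AssertionError(
            f"{stage_name}: max abs diff {diff:.3e} exceeds atol {atol:.1e}"
        )
    return diff


# Usage with made-up values standing in for flattened stage outputs.
reference = [0.10, -0.25, 0.75]
candidate = [0.1002, -0.2501, 0.7499]
print(assert_stage_parity("text_embeddings", candidate, reference))
```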

Only image generation is supported so far.

amepas requested review from chandrasekhard2 and eltsai June 29, 2026 20:49
amepas requested a review from entrpn as a code owner June 29, 2026 20:49
amepas marked this pull request as draft June 29, 2026 20:53

class GenerateFlux2KleinE2ETest(unittest.TestCase):

  def test_end_to_end_parity_and_offloading(self):

Review comment (Collaborator):

This test is very likely to fail on the GitHub runner because of the hardcoded values. We usually don't run e2e tests on the GitHub runner, so you can mark it to be skipped there.
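One way to do this with plain `unittest` is sketched below. It assumes the runner is GitHub Actions, which sets the `GITHUB_ACTIONS` environment variable to `"true"`. The class and test names are placeholders, and the repository may already have its own convention (for example a pytest marker) that should be preferred.

```python
import os
import unittest

# GitHub Actions sets GITHUB_ACTIONS="true" in the environment of its runners.
IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true"


class ExampleE2ETest(unittest.TestCase):  # placeholder name, not the PR's class

    @unittest.skipIf(IN_GITHUB_ACTIONS, "e2e tests are not run on the GitHub runner")
    def test_end_to_end(self):
        # Placeholder body; the real test would run the full pipeline.
        self.assertTrue(True)


if __name__ == "__main__":
    unittest.main()
```

Locally and on dedicated hardware the test runs as usual. On the GitHub runner it is reported as skipped rather than failed.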

every single stage against the golden PyTorch reference.
"""
# Set highest precision for strict mathematical parity checks
jax.config.update("jax_default_matmul_precision", "highest")

Review comment (Collaborator):

What is the reason for using "highest" here?
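One plausible reason, not confirmed by the author in this thread: on TPU, JAX's default matmul precision may evaluate float32 matrix multiplies with bfloat16 inputs, while `"highest"` requests full float32 precision. The reduced-precision path can introduce errors large enough to break a strict parity check. The sketch below emulates this in pure Python, with no JAX required. `to_bfloat16` and `dot` are illustrative helpers and the input values are made up.

```python
import struct


def to_bfloat16(x: float) -> float:
    """Round a finite float to the nearest bfloat16 value (ties to even)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # bfloat16 keeps the top 16 bits of a float32; round on the dropped bits.
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    bits = (((bits + rounding_bias) >> 16) << 16) & 0xFFFFFFFF
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def dot(a, b, cast=lambda v: v):
    """Dot product with an optional cast applied to every input element."""
    return sum(cast(x) * cast(y) for x, y in zip(a, b))


a = [1.001, 2.003, 3.007]
b = [0.999, 1.002, 0.998]

full = dot(a, b)                 # inputs kept at full precision
reduced = dot(a, b, to_bfloat16)  # inputs rounded to bfloat16 first

print(f"full precision: {full:.6f}")
print(f"bfloat16 inputs: {reduced:.6f}")
print(f"absolute error: {abs(full - reduced):.6f}")
```

Even for three elements the error lands around the second decimal place, well above the tolerances typically used for strict parity checks.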

amepas changed the title from "Cleaned Flux2 Klein Implementation, with benchmarking done on v6 TPU" to "(WIP) Cleaned Flux2 Klein Implementation, with benchmarking done on v6 TPU" Jun 30, 2026